"""
Phase 8: Mediation Probe — Why Z₁ on CA Fails in the Trilemma Sample
=====================================================================
1. Full multilateral panel mediation (diagnose sample composition vs mediation)
2. Subsample mediation (eurozone, OECD floaters, non-OECD)
3. Regime-contingent CA test (Z₁ × eurozone interaction)

Key question: Is the mediation null because Z doesn't predict CA in the
trilemma sample (sample composition), or because trilemma doesn't mediate?
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
warnings.filterwarnings('ignore')

# Directories are resolved relative to this file: PROJECT_DIR is the parent of
# the directory that contains this script, ROOT_DIR is one level above that.
PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
# Put the sibling "multilateral" project's src directory first on sys.path so
# that the `model` module (which provides PanelGLS) can be imported below.
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Processed inputs of this project and the directory the markdown tables are
# written to. NOTE: the output directory is created as an import-time side effect.
DATA = PROJECT_DIR / "data" / "processed"
OUT_TABLES = PROJECT_DIR / "output" / "tables"
OUT_TABLES.mkdir(parents=True, exist_ok=True)

# Processed data of the multilateral follow-up project; full_panel.csv is read
# from here by full_panel_mediation().
MULTILATERAL_DATA = ROOT_DIR / "multilateral" / "followup" / "data" / "processed"

# Eurozone members with join years (ISO3 code -> first year treated as a
# eurozone observation; the filters below use `year >= join year`).
# NOTE(review): HRV is not listed although the panel is kept through 2024 —
# confirm whether that omission is intentional.
EUROZONE_JOIN = {
    'AUT': 1999, 'BEL': 1999, 'FIN': 1999, 'FRA': 1999, 'DEU': 1999,
    'IRL': 1999, 'ITA': 1999, 'LUX': 1999, 'NLD': 1999, 'PRT': 1999,
    'ESP': 1999, 'GRC': 2001, 'SVN': 2007, 'CYP': 2008, 'MLT': 2008,
    'SVK': 2009, 'EST': 2011, 'LVA': 2014, 'LTU': 2015,
}
# Every country that is a eurozone member at any point, regardless of year.
EUROZONE_ISO3 = set(EUROZONE_JOIN.keys())

# ISO3 codes treated as OECD. The set is static (not time-varying), so a
# country is classified as OECD for every year of the panel.
# Note that CYP and MLT are in EUROZONE_JOIN but not in this set.
OECD = {
    'AUS', 'AUT', 'BEL', 'CAN', 'CHL', 'COL', 'CRI', 'CZE', 'DNK', 'EST',
    'FIN', 'FRA', 'DEU', 'GRC', 'HUN', 'ISL', 'IRL', 'ISR', 'ITA', 'JPN',
    'KOR', 'LVA', 'LTU', 'LUX', 'MEX', 'NLD', 'NZL', 'NOR', 'POL', 'PRT',
    'SVK', 'SVN', 'ESP', 'SWE', 'CHE', 'TUR', 'GBR', 'USA',
}


def stars(p):
    """Return the significance marker for p-value *p*.

    '***' below 0.01, '**' below 0.05, '*' below 0.1, otherwise ''.
    All comparisons are strict, so a NaN p-value yields ''.
    """
    # Ordered from the strictest cutoff so the first match wins.
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def fmt(val, se, p):
    """Format one estimate as a pair of table cells.

    Returns ``(coef_cell, se_cell)``: the coefficient with four decimals
    followed by its significance stars, and the standard error with four
    decimals wrapped in parentheses.
    """
    coef_cell = f"{val:.4f}{stars(p)}"
    se_cell = f"({se:.4f})"
    return coef_cell, se_cell


def run_panel_gls(df, y_var, x_vars, label):
    """Fit a PanelGLS model on the complete cases of *df* and summarise it.

    Args:
        df: Panel containing *y_var*, every name in *x_vars*, 'iso3' and 'year'.
        y_var: Name of the dependent-variable column.
        x_vars: List of regressor column names, in the order passed to the model.
        label: Model name used in console output and stored under 'model'.

    Returns:
        A dict with 'model', 'n_obs', 'n_countries', 'r_squared', 'rho' and,
        for each regressor, '<name>_coef', '<name>_se' and '<name>_p'.
        ``None`` when fewer than 50 complete rows remain or when fitting
        raises; in both cases a "skipping" message is printed.
    """
    cols = [y_var] + x_vars + ['iso3', 'year']
    sample = df[cols].dropna()
    n_rows = len(sample)
    if n_rows < 50:
        print(f"  {label}: insufficient obs ({n_rows}), skipping")
        return None

    estimator = PanelGLS()
    y = sample[y_var].values
    X = sample[x_vars].values
    try:
        estimator.fit(y, X, sample['iso3'].values, sample['year'].values)
    except Exception as e:
        # Best effort: one failing specification must not stop the whole run.
        print(f"  {label}: GLS failed ({e}), skipping")
        return None

    result = {
        'model': label,
        'n_obs': estimator.n_obs,
        'n_countries': estimator.n_countries,
        'r_squared': estimator.r_squared,
        'rho': estimator.rho,
    }
    print(f"\n  {label} (N={estimator.n_obs}, R²={estimator.r_squared:.4f})")

    # Estimates are read positionally, in the same order as x_vars.
    for pos, name in enumerate(x_vars):
        p_value = estimator.pvalues[pos]
        coef = estimator.beta[pos]
        std_err = estimator.se[pos]
        print(f"    {name:30s} {coef:8.4f} ({std_err:.4f}) {stars(p_value)}")
        result[f'{name}_coef'] = coef
        result[f'{name}_se'] = std_err
        result[f'{name}_p'] = p_value

    return result


def write_table(results, filename, title, note=None):
    """Write regression results as a markdown table under ``OUT_TABLES``.

    Args:
        results: List of result dicts, one per model (table column). Each dict
            needs 'model', 'n_obs', 'r_squared' and 'n_countries', plus a
            '<var>_coef' / '<var>_se' / '<var>_p' triple per regressor.
            Nothing is written when the list is empty or ``None``.
        filename: File name of the table inside ``OUT_TABLES``.
        title: Text of the level-1 markdown heading.
        note: Optional footnote; when omitted, a default note and the
            significance legend are appended.

    Returns:
        None. The file is written as UTF-8 and its path is printed.
    """
    if not results:
        return

    lines = [f"# {title}\n"]

    # Collect variable names in first-seen order across all models.
    # Only the trailing '_coef' is stripped: str.replace() would also remove
    # an inner '_coef' and break the lookups for such a variable below.
    suffix = '_coef'
    all_vars = []
    for r in results:
        for k in r:
            if k.endswith(suffix):
                vname = k[:-len(suffix)]
                if vname not in all_vars:
                    all_vars.append(vname)

    model_labels = [r['model'] for r in results]
    header = "| Variable | " + " | ".join(model_labels) + " |"
    sep = "|:---|" + "|".join(["---:" for _ in results]) + "|"
    lines.append(header)
    lines.append(sep)

    # Two rows per variable: coefficient with stars, then standard error.
    # Models that do not include the variable get empty cells.
    for var in all_vars:
        coef_row = f"| {var} |"
        se_row = "| |"
        for r in results:
            if f'{var}_coef' in r:
                c, s = fmt(r[f'{var}_coef'], r[f'{var}_se'], r[f'{var}_p'])
                coef_row += f" {c} |"
                se_row += f" {s} |"
            else:
                coef_row += " |"
                se_row += " |"
        lines.append(coef_row)
        lines.append(se_row)

    # Same divider row again, separating estimates from the fit statistics.
    lines.append(sep)
    n_row = "| N |"
    r2_row = "| R² |"
    nc_row = "| Countries |"
    for r in results:
        n_row += f" {r['n_obs']} |"
        r2_row += f" {r['r_squared']:.4f} |"
        nc_row += f" {r['n_countries']} |"
    lines.append(n_row)
    lines.append(r2_row)
    lines.append(nc_row)

    if note:
        lines.append(f"\n{note}")
    else:
        lines.append("\n*Panel GLS with country and year fixed effects. "
                     "Standard errors in parentheses.*")
        lines.append("*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*")

    path = OUT_TABLES / filename
    # The table contains non-ASCII text (R², Z₁, ×, →); an explicit encoding
    # keeps the write from failing on platforms whose default is not UTF-8.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {path}")


# ── 1. Full Multilateral Panel Mediation ────────────────────────────

def full_panel_mediation(tri_df):
    """Merge the multilateral full panel with trilemma indices and run M1-M4.

    Reads ``full_panel.csv`` from ``MULTILATERAL_DATA``, keeps years up to
    2024, left-joins the trilemma columns on (iso3, year) and estimates four
    PanelGLS models of ``ca_gdp``:

    * M1 - Z + controls on every merged observation
    * M2 - the same regressors, restricted to rows that have trilemma data
    * M3 - M2 plus ``mi_index`` and ``ers_index``
    * M4 - ``mi_index``, ``ers_index`` and controls only (no Z)

    The models that could be estimated are written to
    ``phase8_full_panel_mediation.md`` and a Z_1 diagnostic comparing M1, M2
    and M3 is printed.

    Args:
        tri_df: Trilemma panel with at least the columns 'iso3', 'year',
            'mi_index', 'ers_index', 'fo_index', 'eurozone' and
            'oecd_floater'.

    Returns:
        The merged DataFrame: the full panel plus the trilemma columns and a
        0/1 'has_trilemma' flag.
    """
    print("\n" + "=" * 60)
    print("1. FULL MULTILATERAL PANEL MEDIATION")
    print("=" * 60)

    # Load full panel (filter to year <= 2024 for current data)
    fp = pd.read_csv(MULTILATERAL_DATA / "full_panel.csv")
    fp = fp[fp['year'] <= 2024].copy()
    print(f"  Full panel: {len(fp)} obs, {fp['iso3'].nunique()} countries")

    # Get trilemma indices from trilemma panel; rows missing either MI or ERS
    # are dropped so that "has trilemma data" means both are present.
    tri_cols = ['iso3', 'year', 'mi_index', 'ers_index', 'fo_index',
                'eurozone', 'oecd_floater']
    tri_merge = tri_df[tri_cols].dropna(subset=['mi_index', 'ers_index']).copy()
    print(f"  Trilemma data available: {len(tri_merge)} obs")

    # Left join keeps every full-panel row, matched or not.
    merged = fp.merge(tri_merge, on=['iso3', 'year'], how='left')
    merged['has_trilemma'] = merged['mi_index'].notna().astype(int)

    # Identify which obs have trilemma data
    n_with = merged['has_trilemma'].sum()
    n_without = len(merged) - n_with
    print(f"  Merged panel: {len(merged)} obs ({n_with} with trilemma, "
          f"{n_without} without)")

    z_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth']
    tri_vars = ['mi_index', 'ers_index']

    # M1: Z + controls, full sample (expect Z₁ ≈ 35***)
    print("\n  --- M1: Full sample (all obs) ---")
    m1 = run_panel_gls(merged, 'ca_gdp', z_vars + controls, 'M1: Full')

    # M2: Z + controls, restricted to obs WITH trilemma data
    print("\n  --- M2: Restricted to obs with trilemma data ---")
    restricted = merged[merged['has_trilemma'] == 1].copy()
    m2 = run_panel_gls(restricted, 'ca_gdp', z_vars + controls,
                       'M2: Tri sample')

    # M3: Z + MI + ERS + controls, same restricted sample
    print("\n  --- M3: Z + MI + ERS + controls (trilemma sample) ---")
    m3 = run_panel_gls(restricted, 'ca_gdp',
                       z_vars + tri_vars + controls,
                       'M3: Z + Tri')

    # M4: MI + ERS + controls only (no Z)
    print("\n  --- M4: Trilemma only (no Z) ---")
    m4 = run_panel_gls(restricted, 'ca_gdp',
                       tri_vars + controls,
                       'M4: Tri only')

    # Only models that were actually estimated become table columns.
    results = [r for r in (m1, m2, m3, m4) if r]

    write_table(results, "phase8_full_panel_mediation.md",
                "Full Panel Mediation: Diagnosing Sample Composition vs Mediation",
                note=("*Panel GLS with country and year fixed effects. "
                      "M1 uses full multilateral panel. M2-M4 restrict to observations "
                      "with trilemma data available. Key diagnostic: if Z₁ drops from "
                      "M1→M2, the issue is sample composition, not mediation.*\n"
                      "*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*"))

    # Report the key diagnostic. The models are referenced by name rather than
    # by position in `results`: if one of M1-M3 was skipped, positional
    # indexing would print a different model under the wrong label.
    print("\n  --- KEY DIAGNOSTIC ---")
    if m1 and m2 and m3:
        z1_m1 = m1.get('Z_1_coef', np.nan)
        z1_m1_p = m1.get('Z_1_p', np.nan)
        z1_m2 = m2.get('Z_1_coef', np.nan)
        z1_m2_p = m2.get('Z_1_p', np.nan)
        z1_m3 = m3.get('Z_1_coef', np.nan)
        z1_m3_p = m3.get('Z_1_p', np.nan)

        print(f"    Z₁ on CA (full sample):       {z1_m1:.4f} (p={z1_m1_p:.4f})")
        print(f"    Z₁ on CA (trilemma sample):    {z1_m2:.4f} (p={z1_m2_p:.4f})")
        print(f"    Z₁ on CA (+ MI, ERS):          {z1_m3:.4f} (p={z1_m3_p:.4f})")

        # Guard the ratios against a (near-)zero or NaN denominator.
        if abs(z1_m1) > 1e-8:
            comp_drop = (1 - abs(z1_m2) / abs(z1_m1)) * 100
            print(f"    Sample composition drop: {comp_drop:.1f}%")
        if abs(z1_m2) > 1e-8:
            med_atten = (1 - abs(z1_m3) / abs(z1_m2)) * 100
            print(f"    Mediation attenuation:   {med_atten:.1f}%")
    else:
        missing = [name for name, res in (('M1', m1), ('M2', m2), ('M3', m3))
                   if not res]
        print(f"    Skipped: {', '.join(missing)} not estimated")

    return merged


# ── 2. Subsample Mediation ──────────────────────────────────────────

def subsample_mediation(merged):
    """Run the Z -> CA mediation pair on three subsamples of *merged*.

    Subsamples:
      * Eurozone   - members of EUROZONE_JOIN from their join year onward
      * OECD Float - OECD countries never in the eurozone, years >= 1999
      * Non-OECD   - every country outside the OECD set

    For each subsample with at least 50 rows, ``ca_gdp`` is regressed on
    Z + controls, and again with the available trilemma indices added. The
    Z_1 attenuation between the two is printed, and all estimated models are
    written to ``phase8_subsample_mediation.md``.

    NOTE(review): the subsamples are not a partition. Eurozone members that
    are not in OECD (CYP, MLT) also fall into Non-OECD, and pre-join years of
    OECD eurozone members belong to no subsample — confirm this is intended.
    """
    print("\n" + "=" * 60)
    print("2. SUBSAMPLE MEDIATION")
    print("=" * 60)

    z_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth']
    tri_vars = ['mi_index', 'ers_index']

    # Eurozone post-join rows, stacked country by country in EUROZONE_JOIN order.
    ez_parts = [
        merged[(merged['iso3'] == code) & (merged['year'] >= joined)]
        for code, joined in EUROZONE_JOIN.items()
    ]
    if ez_parts:
        ez_df = pd.concat(ez_parts, ignore_index=True)
    else:
        ez_df = pd.DataFrame()

    in_oecd = merged['iso3'].isin(OECD)
    ever_eurozone = merged['iso3'].isin(EUROZONE_ISO3)

    # OECD floaters (non-eurozone OECD, post-1999)
    float_df = merged[in_oecd & (~ever_eurozone) & (merged['year'] >= 1999)].copy()

    # Non-OECD
    non_oecd = merged[~merged['iso3'].isin(OECD)].copy()

    all_results = []

    for label, frame in (('Eurozone', ez_df),
                         ('OECD Float', float_df),
                         ('Non-OECD', non_oecd)):
        n_rows = len(frame)
        if n_rows < 50:
            print(f"\n  {label}: insufficient obs ({n_rows}), skipping")
            continue

        print(f"\n  --- {label} (N={n_rows}, "
              f"{frame['iso3'].nunique()} countries) ---")

        # Baseline: Z + controls.
        baseline = run_panel_gls(frame, 'ca_gdp', z_vars + controls,
                                 f'{label}: Z')

        # Mediated: add each trilemma index with more than 50 non-missing values.
        usable_tri = []
        for col in tri_vars:
            if col in frame.columns and frame[col].notna().sum() > 50:
                usable_tri.append(col)

        mediated = None
        if usable_tri:
            mediated = run_panel_gls(frame, 'ca_gdp',
                                     z_vars + usable_tri + controls,
                                     f'{label}: Z+Tri')

        for fitted in (baseline, mediated):
            if fitted:
                all_results.append(fitted)

        # Attenuation of Z_1 needs both fits and a non-zero baseline.
        if not (baseline and mediated):
            continue
        z1_base = baseline.get('Z_1_coef', 0)
        z1_med = mediated.get('Z_1_coef', 0)
        if abs(z1_base) > 1e-8:
            atten = (1 - abs(z1_med) / abs(z1_base)) * 100
            print(f"    {label} Z₁ attenuation: {z1_base:.4f} → "
                  f"{z1_med:.4f} ({atten:+.1f}%)")

    write_table(all_results, "phase8_subsample_mediation.md",
                "Subsample Mediation: Z → CA With/Without Trilemma Controls",
                note=("*Panel GLS with country and year fixed effects. "
                      "Each subsample tested with and without MI/ERS controls. "
                      "Attenuation indicates trilemma mediation strength.*\n"
                      "*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*"))


# ── 3. Regime-Contingent CA Test ────────────────────────────────────

def regime_contingent(merged):
    """Test whether the Z_1 effect on CA differs for eurozone members.

    Adds two columns to *merged* in place: 'ez_post' (1 for a eurozone member
    from its join year onward, else 0) and 'Z_1_x_eurozone' (Z_1 * ez_post).
    It then estimates a baseline model and three interaction models (full
    panel, OECD only, years >= 1990) and writes the estimated ones to
    ``phase8_regime_contingent.md``.
    """
    print("\n" + "=" * 60)
    print("3. REGIME-CONTINGENT CA TEST")
    print("=" * 60)

    z_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth']
    baseline_spec = z_vars + controls
    interaction_spec = z_vars + ['ez_post', 'Z_1_x_eurozone'] + controls

    # Create eurozone dummy (post-join only). Countries absent from
    # EUROZONE_JOIN map to NaN, and a comparison with NaN is False, i.e. 0.
    join_year = merged['iso3'].map(EUROZONE_JOIN)
    merged['ez_post'] = (merged['year'] >= join_year).astype(int)

    # Create interaction: Z_1 × eurozone
    merged['Z_1_x_eurozone'] = merged['Z_1'] * merged['ez_post']

    # Restricted samples are taken after the new columns exist on `merged`.
    oecd_df = merged[merged['iso3'].isin(OECD)].copy()
    post90 = merged[merged['year'] >= 1990].copy()

    # (console banner, sample, regressors, model label), in estimation order.
    specs = (
        ("\n  --- M1: Z → CA (full) ---",
         merged, baseline_spec, 'Baseline'),
        ("\n  --- M2: Z + eurozone + Z₁×eurozone ---",
         merged, interaction_spec, 'Z + EZ interact'),
        ("\n  --- M3: OECD only + Z₁×eurozone ---",
         oecd_df, interaction_spec, 'OECD + EZ int'),
        ("\n  --- M4: Post-1990 + Z₁×eurozone ---",
         post90, interaction_spec, 'Post-1990 + EZ'),
    )

    results = []
    for banner, sample, regressors, model_label in specs:
        print(banner)
        fitted = run_panel_gls(sample, 'ca_gdp', regressors, model_label)
        if fitted:
            results.append(fitted)

    write_table(results, "phase8_regime_contingent.md",
                "Regime-Contingent CA Effect: Z₁ × Eurozone Interaction",
                note=("*Panel GLS with country and year fixed effects. "
                      "ez_post = 1 for eurozone members after their accession year. "
                      "Z_1_x_eurozone tests whether the demographic CA effect "
                      "differs for eurozone members (who lack exchange rate adjustment).*\n"
                      "*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*"))


# ── Main ─────────────────────────────────────────────────────────────

def main():
    """Run the Phase 8 pipeline end to end.

    Loads ``trilemma_panel.csv`` from ``DATA``, runs the full-panel mediation
    (which returns the merged panel), then passes that panel to the subsample
    mediation and the regime-contingent test, in that order.
    """
    rule = "=" * 70
    print(rule)
    print("PHASE 8: MEDIATION PROBE")
    print(rule)

    tri_df = pd.read_csv(DATA / "trilemma_panel.csv")
    n_countries = tri_df['iso3'].nunique()
    print(f"Trilemma panel: {len(tri_df)} obs, {n_countries} countries")

    # Step 1 builds the merged panel that steps 2 and 3 consume.
    merged = full_panel_mediation(tri_df)
    for step in (subsample_mediation, regime_contingent):
        step(merged)

    print("\n" + rule)
    print("PHASE 8 COMPLETE")
    print(rule)


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
